{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/derek/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n",
      "/home/derek/anaconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: numpy.dtype size changed, may indicate binary incompatibility. Expected 96, got 88\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "import os\n",
    "import re\n",
    "import numpy as np\n",
    "import pickle\n",
    "from collections import Counter\n",
    "from preprocessing import CharacterIndexer, SlotIndexer, IntentIndexer\n",
    "from gensim.models import Word2Vec"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### read data\n",
    "\n",
    "the `PlayMusic` file has a bad entry so the commented cell must be run once, then the original `PlayMusic` file deleted."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # fix bad line in the PlayMusic file\n",
    "# # then manually rename\n",
    "# def decode(s, encoding=\"utf8\", errors=\"ignore\"):\n",
    "#     return s.decode(encoding=encoding, errors=errors)\n",
    "\n",
    "# raw_json_file = open('data/snips/train/train_PlayMusic_full.json', 'rb')\n",
    "# raw_json = decode(bytes(raw_json_file.read()))\n",
    "# all_str = raw_json[ raw_json.find(\"{\"): ]\n",
    "# all_obj = json.loads(all_str)\n",
    "# with open('data/snips/train/train_PlayMusic_full_fixed.json', 'w') as outfile:\n",
    "#     json.dump(all_obj, outfile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_PlayMusic_full_fixed.json PlayMusic\n",
      "train_SearchCreativeWork_full.json SearchCreativeWork\n",
      "train_BookRestaurant_full.json BookRestaurant\n",
      "train_RateBook_full.json RateBook\n",
      "train_GetWeather_full.json GetWeather\n",
      "train_AddToPlaylist_full.json AddToPlaylist\n",
      "train_SearchScreeningEvent_full.json SearchScreeningEvent\n"
     ]
    }
   ],
   "source": [
    "train_sents, train_tags, train_intents = [], [], []\n",
    "path = 'data/snips/train'\n",
    "for filename in os.listdir(path):\n",
    "    if 'json' in filename:\n",
    "        with open(path + '/' + filename, encoding='utf8') as json_file:\n",
    "            intent = filename.split('_')[1]\n",
    "            print(filename, intent)\n",
    "    #         try:\n",
    "            data = json.load(json_file)\n",
    "            data = data[intent]\n",
    "            for sent in data:\n",
    "                s, t = [], []\n",
    "                for dct in sent['data']:\n",
    "                    if 'entity' in dct.keys():\n",
    "                        t.append(dct['entity'])\n",
    "                        s.append(dct['text'])\n",
    "                    else:\n",
    "                        t.append(\"NONE\")\n",
    "                        s.append(dct['text'])\n",
    "                train_sents.append(s)\n",
    "                train_tags.append(t)\n",
    "                train_intents.append(intent)\n",
    "#         except UnicodeDecodeError:\n",
    "#             pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(13784, 13784, 13784)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(train_sents), len(train_tags), len(train_intents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('PlayMusic', 2000),\n",
       " ('GetWeather', 2000),\n",
       " ('BookRestaurant', 1973),\n",
       " ('SearchScreeningEvent', 1959),\n",
       " ('RateBook', 1956),\n",
       " ('SearchCreativeWork', 1954),\n",
       " ('AddToPlaylist', 1942)]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(train_intents).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "validate_SearchCreativeWork.json SearchCreativeWork\n",
      "validate_AddToPlaylist.json AddToPlaylist\n",
      "validate_BookRestaurant.json BookRestaurant\n",
      "validate_SearchScreeningEvent.json SearchScreeningEvent\n",
      "validate_PlayMusic.json PlayMusic\n",
      "validate_GetWeather.json GetWeather\n",
      "validate_RateBook.json RateBook\n"
     ]
    }
   ],
   "source": [
    "val_sents, val_tags, val_intents = [], [], []\n",
    "path = 'data/snips/validate'\n",
    "for filename in os.listdir(path):\n",
    "    if 'json' in filename:\n",
    "        with open(path + '/' + filename) as json_file:\n",
    "            intent = filename.split('_')[1]\n",
    "            intent = intent.split('.')[0]\n",
    "            print(filename, intent)\n",
    "#             try:\n",
    "            data = json.load(json_file)\n",
    "            data = data[intent]\n",
    "            for sent in data:\n",
    "                s, t = [], []\n",
    "                for dct in sent['data']:\n",
    "                    if 'entity' in dct.keys():\n",
    "                        t.append(dct['entity'])\n",
    "                        s.append(dct['text'])\n",
    "                    else:\n",
    "                        t.append(\"NONE\")\n",
    "                        s.append(dct['text'])\n",
    "                val_sents.append(s)\n",
    "                val_tags.append(t)\n",
    "                val_intents.append(intent)\n",
    "#             except UnicodeDecodeError:\n",
    "#                 pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(700, 700, 700)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(val_sents), len(val_tags), len(val_intents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('SearchCreativeWork', 100),\n",
       " ('AddToPlaylist', 100),\n",
       " ('BookRestaurant', 100),\n",
       " ('SearchScreeningEvent', 100),\n",
       " ('PlayMusic', 100),\n",
       " ('GetWeather', 100),\n",
       " ('RateBook', 100)]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(val_intents).most_common()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['I need to hear the ', 'song', ' ', 'Aspro Mavro', ' from ', 'Bill Szymczyk', ' on ', 'Youtube']\n",
      "['NONE', 'music_item', 'NONE', 'track', 'NONE', 'artist', 'NONE', 'service']\n",
      "PlayMusic\n"
     ]
    }
   ],
   "source": [
    "print(train_sents[0])\n",
    "print(train_tags[0])\n",
    "print(train_intents[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# preprocess sentences\n",
    "def cleanup(sentlist, taglist):\n",
    "    newsents = []\n",
    "    newtags  = []\n",
    "    for idx, sent in enumerate(sentlist):\n",
    "        newsent, newtag = [], []\n",
    "        for jdx, phrase in  enumerate(sent):\n",
    "            for c in ['.', ',', '!', '?', ]:\n",
    "                phrase = phrase.replace(c, '')\n",
    "            tt = phrase.split()\n",
    "            for kdx, t in enumerate(tt):\n",
    "                # digit replacement\n",
    "                # if t.isdigit():\n",
    "                #     newsent.append(digits.get(t, '##'))\n",
    "                # else:\n",
    "                newsent.append(t.lower())\n",
    "                newtag.append(taglist[idx][jdx])\n",
    "        newsents.append(newsent)\n",
    "        newtags.append(newtag)\n",
    "    return newsents, newtags"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_sents, train_tags = cleanup(train_sents, train_tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_sents, val_tags = cleanup(val_sents, val_tags)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['i', 'need', 'to', 'hear', 'the', 'song', 'aspro', 'mavro', 'from', 'bill', 'szymczyk', 'on', 'youtube']\n",
      "['NONE', 'NONE', 'NONE', 'NONE', 'NONE', 'music_item', 'track', 'track', 'NONE', 'artist', 'artist', 'NONE', 'service']\n",
      "PlayMusic\n"
     ]
    }
   ],
   "source": [
    "print(train_sents[0])\n",
    "print(train_tags[0])\n",
    "print(train_intents[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "40"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(list(set([t for s in train_tags for t in s])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('NONE', 59611),\n",
       " ('object_name', 10792),\n",
       " ('playlist', 5511),\n",
       " ('timeRange', 4787),\n",
       " ('object_type', 4216),\n",
       " ('artist', 3793),\n",
       " ('movie_name', 2794),\n",
       " ('spatial_relation', 2103),\n",
       " ('rating_value', 1957),\n",
       " ('city', 1829)]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "Counter(list([t for s in train_tags for t in s])).most_common(10)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### train word2vec embeddings\n",
    "\n",
    "using pretrained embeddings has been found to increase performance in various NLP tasks.\n",
    "\n",
    "we could train on an external corpus such as brown but we will just use the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "training done!\n"
     ]
    }
   ],
   "source": [
    "# train and save model\n",
    "model = Word2Vec(train_sents, size=200, min_count=1, window=3, workers=3, iter=5)\n",
    "model.save('model/snips_w2v.gensimmodel')\n",
    "print('training done!')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get model vocabulary\n",
    "vocab = dict([(k, v.index) for k, v in model.wv.vocab.items()])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[('listen', 0.997948169708252),\n",
       " ('would', 0.9847648739814758),\n",
       " ('want', 0.9830597043037415),\n",
       " ('w', 0.9820548295974731),\n",
       " ('watch', 0.9763119220733643),\n",
       " ('to', 0.9758093357086182),\n",
       " ('insult', 0.972419798374176),\n",
       " ('go', 0.9699915647506714),\n",
       " ('see', 0.9680747389793396),\n",
       " ('kill', 0.9657482504844666)]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# test\n",
    "model.wv.most_similar('hear')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### encode data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "11895"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "vocab = list(set([w for s in train_sents for w in s]))\n",
    "len(vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "VOCABSIZE = 10000"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "15"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "lens = [len(s) for s in train_sents]\n",
    "_MAXLEN = int(np.round(np.mean(lens) + 2*np.std(lens)))\n",
    "_MAXLEN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.9640887986070806"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum([1 if lens[i] <= _MAXLEN else 0 for i in range(len(lens))])/len(lens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "sent_indexer = CharacterIndexer(max_sent_len=15, max_word_mode='std', max_word_vocab=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "fit(): splitting...\n",
      "fit(): max word len set to 9\n",
      "fit(): creating conversion dictionaries...\n",
      "fit(): tru word vocab: 10000\n",
      "fit(): tru char vocab: 91\n",
      "fit(): done!\n"
     ]
    }
   ],
   "source": [
    "sent_indexer.fit(train_sents, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((13784, 15), (13784, 15, 9), (700, 15), (700, 15, 9))"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the sentence data\n",
    "trn_text_idx, trn_char_idx = sent_indexer.transform(train_sents)\n",
    "tst_text_idx, tst_char_idx = sent_indexer.transform(val_sents)\n",
    "trn_text_idx.shape, trn_char_idx.shape, tst_text_idx.shape, tst_char_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((13784, 15, 1), (700, 15, 1))"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the slot data\n",
    "slot_indexer = SlotIndexer(max_len=15)\n",
    "slot_indexer.fit(train_tags)\n",
    "trn_slot_idx = slot_indexer.transform(train_tags)\n",
    "tst_slot_idx = slot_indexer.transform(val_tags)\n",
    "trn_slot_idx.shape, tst_slot_idx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((13784, 1), (700, 1))"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# transform the intent data\n",
    "int_indexer = IntentIndexer()\n",
    "int_indexer.fit(train_intents)\n",
    "trn_int_idx = int_indexer.transform(train_intents)\n",
    "tst_int_idx = int_indexer.transform(val_intents)\n",
    "trn_int_idx.shape, tst_int_idx.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### save data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(train_sents,   open('data/snips/train_sents.pkl', 'wb'))\n",
    "pickle.dump(train_tags,    open('data/snips/train_slots.pkl', 'wb'))\n",
    "pickle.dump(train_intents, open('data/snips/train_intents.pkl', 'wb'))\n",
    "pickle.dump(val_sents,   open('data/snips/test_sents.pkl', 'wb'))\n",
    "pickle.dump(val_tags,    open('data/snips/test_slots.pkl', 'wb'))\n",
    "pickle.dump(val_intents, open('data/snips/test_intents.pkl', 'wb'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.save('encoded/snips_x_trn_text.npy', trn_text_idx)\n",
    "np.save('encoded/snips_x_tst_text.npy', tst_text_idx)\n",
    "\n",
    "np.save('encoded/snips_x_trn_char.npy', trn_char_idx)\n",
    "np.save('encoded/snips_x_tst_char.npy', tst_char_idx)\n",
    "\n",
    "np.save('encoded/snips_y_trn_slot.npy', trn_slot_idx)\n",
    "np.save('encoded/snips_y_tst_slot.npy', tst_slot_idx)\n",
    "\n",
    "np.save('encoded/snips_y_trn_ints.npy', trn_int_idx)\n",
    "np.save('encoded/snips_y_tst_ints.npy', tst_int_idx)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(sent_indexer, open(\"encoded/snips_sent_indexer.pkl\", \"wb\"))\n",
    "pickle.dump(slot_indexer, open(\"encoded/snips_slot_indexer.pkl\", \"wb\"))\n",
    "pickle.dump(int_indexer,  open(\"encoded/snips_int_indexer.pkl\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "pickle.dump(vocab,  open(\"model/snips_w2v_vocab.pkl\", \"wb\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "base"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
